# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore import nn
from mindspore import ops

from accspeed.ops.dropout_mask_decoder.dropout_mask_decoder_impl import \
    get_dropout_mask_decoder


class DropoutMaskDecoder(nn.Cell):
    """Cell wrapper around the dropout-mask decoder operator.

    The operator is created by ``get_dropout_mask_decoder(bsz)``. This cell
    forwards its input to that operator and provides a ``shard`` method that
    attaches the distributed-layout attributes to it.

    :param bsz: Value forwarded to ``get_dropout_mask_decoder`` (presumably the
                batch size -- confirm against the operator implementation).
    """

    def __init__(self, bsz):
        super().__init__()
        self.mask_decoder = get_dropout_mask_decoder(bsz)
        # Not used by ``construct`` at present; kept so the attribute remains
        # available to any code that already references it.
        self.tensor_sz_op = ops.Size()

    def construct(self, input_mask):
        """Run the mask decoder operator on ``input_mask``.

        :param input_mask: Input handed to the decoder operator unchanged.
        :return: Whatever the decoder operator returns for ``input_mask``.
        """
        # NOTE: padding the mask up to a multiple of 128 (via ``ops.pad``) was
        # sketched here earlier but is disabled; the mask is passed as-is.
        return self.mask_decoder(input_mask)

    def shard(self, in_strategy=None, out_strategy=None):
        """Distributed configuration of DropoutMaskDecoder

        :param in_strategy: Describe the split strategy of operator input. Only
                            ``in_strategy[0][0]`` (dp) and ``in_strategy[0][1]``
                            (mp) are read here, so the first entry needs at least
                            two elements, e.g. ``((dp, mp),)``. Although the
                            default is None, a strategy must be supplied.
        :param out_strategy: Describe the split strategy of operator output, it is only for certain
                            operators, such as MatMul. Accepted for interface
                            compatibility; it is not used by this method.
                            Default: None.
        :raises TypeError: If ``in_strategy`` is None.
        :return: None.
        """
        # Fail early with a clear message instead of the opaque
        # "'NoneType' object is not subscriptable" raised further down, and
        # before the operator has been partially configured.
        if in_strategy is None:
            raise TypeError(
                "DropoutMaskDecoder.shard requires 'in_strategy' as a tuple of "
                "tuples such as ((dp, mp),), but got None."
            )

        self.mask_decoder.shard(in_strategy)
        dp = in_strategy[0][0]
        mp = in_strategy[0][1]

        # Layout attributes for the parallel framework: a [dp, mp] device
        # matrix with the same [1, 0] tensor map for the single input and the
        # single output (presumably dim 0 -> dp, dim 1 -> mp; confirm against
        # the MindSpore tensor-map convention).
        self.mask_decoder.add_prim_attr("dev_matrix_shape", [dp, mp])
        self.mask_decoder.add_prim_attr("inputs_tensor_map", [[1, 0], ])
        self.mask_decoder.add_prim_attr("outputs_tensor_map", [[1, 0], ])

        self.mask_decoder.add_prim_attr("as_loss_divisor", 0)
        self.mask_decoder.add_prim_attr("empty_mirror_ops", 0)
